# Copyright (c) Alibaba, Inc. and its affiliates.
import os
from typing import List

from swift.llm import (
    ExportArguments,
    PtEngine,
    RequestConfig,
    Template,
    prepare_model_template,
)
from swift.utils import get_logger

logger = get_logger()


def replace_and_concat(
    template: "Template", template_list: List, placeholder: str, keyword: str
) -> str:
    """Render a list of template fragments into one string.

    Each element of ``template_list`` is handled by its type:

    * ``str`` -- appended after replacing ``placeholder`` with ``keyword``.
    * ``tuple``/``list`` whose first item is an ``int`` -- passed as a whole
      to ``template.tokenizer.decode`` and the result appended.
    * any other ``tuple``/``list`` -- every item must be ``"bos_token_id"``
      or ``"eos_token_id"``; the tokenizer's ``bos_token`` / ``eos_token``
      is appended for each.
    * anything else is ignored.

    Args:
        template: Object exposing ``tokenizer`` (only used for non-string
            fragments).
        template_list: Fragments to render, in order.
        placeholder: Text to substitute inside string fragments. An empty
            placeholder disables substitution.
        keyword: Replacement text for ``placeholder``.

    Returns:
        The concatenation of all rendered fragments.

    Raises:
        ValueError: If a token-name fragment contains an unknown name.
    """
    rendered: List[str] = []
    for fragment in template_list:
        if isinstance(fragment, str):
            # ``"ab".replace("", "X")`` yields ``"XaXbX"``; an empty placeholder
            # means "nothing to substitute", so leave the text untouched.
            rendered.append(
                fragment.replace(placeholder, keyword) if placeholder else fragment
            )
            continue
        if not isinstance(fragment, (tuple, list)):
            continue
        if not fragment:
            # An empty sequence contributes nothing; indexing it would raise.
            continue
        if isinstance(fragment[0], int):
            # A sequence of ints is a run of token ids.
            rendered.append(template.tokenizer.decode(fragment))
            continue
        for attr in fragment:
            if attr == "bos_token_id":
                rendered.append(template.tokenizer.bos_token)
            elif attr == "eos_token_id":
                rendered.append(template.tokenizer.eos_token)
            else:
                raise ValueError(f"Unknown token: {attr}")
    return "".join(rendered)


def export_to_ollama(args: ExportArguments) -> None:
    """Write an Ollama ``Modelfile`` for the model described by ``args``.

    The function sets ``args.device_map`` to ``"meta"``, creates
    ``args.output_dir`` if needed, prepares the model/template pair and a
    ``PtEngine``, and then writes ``<output_dir>/Modelfile`` containing:

    * a ``FROM`` line pointing at the engine's ``model_dir``;
    * a ``TEMPLATE`` built from the template's system prefix, prefix, prompt
      and suffix, with ``{{SYSTEM}}`` / ``{{QUERY}}`` mapped to the Go
      template variables ``{{ .System }}`` / ``{{ .Prompt }}``;
    * ``PARAMETER stop`` lines for the rendered suffix and the generation
      config's stop words;
    * ``PARAMETER`` lines for temperature, top_k, top_p and repeat_penalty.

    Finally it logs the commands needed to load the Modelfile into Ollama.

    Args:
        args: Export arguments; ``output_dir``, ``temperature``, ``top_k``,
            ``top_p`` and ``repetition_penalty`` are read here, and
            ``device_map`` is overwritten.
    """
    args.device_map = "meta"  # Accelerate load speed.
    logger.info("Exporting to ollama:")
    os.makedirs(args.output_dir, exist_ok=True)
    model, template = prepare_model_template(args)
    pt_engine = PtEngine.from_model_template(model, template)
    logger.info(f"Using model_dir: {pt_engine.model_dir}")
    template_meta = template.template_meta
    modelfile_path = os.path.join(args.output_dir, "Modelfile")

    # Build the whole file content before opening the file, so a failure while
    # rendering does not leave a truncated Modelfile behind.
    system_text = replace_and_concat(
        template, template_meta.system_prefix, "{{SYSTEM}}", "{{ .System }}"
    )
    prefix_text = replace_and_concat(template, template_meta.prefix, "", "")
    prompt_text = replace_and_concat(
        template, template_meta.prompt, "{{QUERY}}", "{{ .Prompt }}"
    )
    # The suffix is used both as the template tail and as a stop sequence.
    suffix_text = replace_and_concat(template, template_meta.suffix, "", "")

    request_config = RequestConfig(
        temperature=args.temperature,
        top_k=args.top_k,
        top_p=args.top_p,
        repetition_penalty=args.repetition_penalty,
    )
    generation_config = pt_engine._prepare_generation_config(request_config)
    pt_engine._add_stop_words(
        generation_config, request_config, template.template_meta
    )

    # Suffix first, then the configured stop words; skip empty/None entries
    # and repeats. A list (not a set) keeps order and tolerates any value type.
    stop_sequences: List = []
    for stop_word in [suffix_text, *generation_config.stop_words]:
        if not stop_word or stop_word in stop_sequences:
            continue
        stop_sequences.append(stop_word)

    sampling_parameters = (
        ("temperature", generation_config.temperature),
        ("top_k", generation_config.top_k),
        ("top_p", generation_config.top_p),
        ("repeat_penalty", generation_config.repetition_penalty),
    )

    pieces = [
        f"FROM {pt_engine.model_dir}\n",
        'TEMPLATE """',
        "{{ if .System }}",
        system_text,
        "{{ else }}",
        prefix_text,
        "{{ end }}",
        "{{ if .Prompt }}",
        prompt_text,
        "{{ end }}",
        "{{ .Response }}",
        suffix_text,
        '"""\n',
    ]
    pieces.extend(f'PARAMETER stop "{stop_word}"\n' for stop_word in stop_sequences)
    # An unset value would be written as the literal text "None", which is not
    # a number; omit the line instead (presumably Ollama then applies its own
    # default -- confirm against the Modelfile reference).
    pieces.extend(
        f"PARAMETER {name} {value}\n"
        for name, value in sampling_parameters
        if value is not None
    )

    with open(modelfile_path, "w", encoding="utf-8") as f:
        f.write("".join(pieces))

    logger.info("Save Modelfile done, you can start ollama by:")
    logger.info("> ollama serve")
    logger.info("In another terminal:")
    logger.info(f"> ollama create my-custom-model -f {modelfile_path}")
    logger.info("> ollama run my-custom-model")
